# coding=utf-8
"""
Author  : Jane
Contact : xijian@ict.ac.cn
Time    : 2021/4/18 21:05
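Desc: Compare the online model's predictions with manually checked labels and count, per class, the rows where they agree.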
"""
import pandas as pd

# Export code -> label name, as produced by the annotation-tool export of 4/18.
# English glosses: 其他军事相关 other military-related, 军事演习 military exercises,
# 主官变动 commander changes, 军事政策 military policy, 军力部署 force deployment,
# 区域冲突 regional conflict, 武器研发 weapons R&D, 非军事 non-military,
# 主题混乱 mixed/unclear topic, 军事待定 military, category pending.
# The earlier 4/3 export used codes 11-20 for the same names:
# export_labels_4_3 = {11: '其他军事相关', 12: '军事演习', 13: '主官变动', 14: '军事政策', 15: '军力部署',
#                      16: '区域冲突', 17: '武器研发', 18: '非军事', 19:'主题混乱', 20:'军事待定'}
export_labels_4_18 = {
    21: '其他军事相关',
    22: '军事演习',
    23: '主官变动',
    24: '军事政策',
    25: '军力部署',
    26: '区域冲突',
    27: '武器研发',
    28: '非军事',
    29: '主题混乱',
    30: '军事待定',
}

# Previous 10-class scheme, kept for reference only (not used):
# label2id = {'军事演习': 1, '主官变动': 2, '军事政策': 3, '军力部署': 4, '区域冲突': 5, '武器研发': 6,
#             '其他军事相关': 0, '非军事': 7, '主题混乱':8, '军事待定':9}
# labels = ['军事演习', '主官变动', '军事政策', '军力部署', '区域冲突', '武器研发', '主题混乱', '军事待定']

# Current class names; a name's position in this list is its class id.
# NOTE(review): export codes 29 ('主题混乱') and 30 ('军事待定') have no entry here,
# so labeled rows carrying those codes cannot be converted to a class id --
# confirm whether such rows should be dropped or given ids.
labels = ['军事演习', '军力部署', '区域冲突', '武器研发', '非军事', '其他军事相关', '军事政策', '主官变动']
label2id = {name: class_id for class_id, name in enumerate(labels)}

def process_labels(df, label_dict, label, mapping=None):
    """Convert exported annotation codes in ``df[label]`` to class ids, in place.

    Every value ``code`` in the column becomes ``mapping[label_dict[code]]``:
    ``label_dict`` turns the export code into a label name, and ``mapping``
    turns that name into a class id. The first row's conversion is printed
    as a quick sanity check.

    Args:
        df: DataFrame containing the column to convert; it is modified in place.
        label_dict: Export code -> label name (e.g. ``export_labels_4_18``).
        label: Name of the column to convert.
        mapping: Label name -> class id. Defaults to the module-level
            ``label2id``, so existing three-argument calls behave as before.

    Returns:
        The same ``df`` object, with ``df[label]`` now holding class ids.

    Raises:
        KeyError: If any value is missing from ``label_dict``, or maps to a
            label name missing from ``mapping``. The message lists every
            offending value instead of failing on the first bare key.
    """
    if mapping is None:
        mapping = label2id

    # Materialize once: the values are used for validation, the sanity print
    # and the conversion itself.
    codes = list(df[label])

    # Collect every unconvertible value up front so a single run reports all
    # of them (e.g. a code whose label name has no class id in ``mapping``).
    unknown = {code for code in codes if label_dict.get(code) not in mapping}
    if unknown:
        raise KeyError('cannot convert column {!r}: values {} have no class id'.format(
            label, sorted(unknown, key=repr)))

    if codes:
        first = codes[0]
        print('*' * 20, first, label_dict[first], mapping[label_dict[first]])

    df[label] = [mapping[label_dict[code]] for code in codes]
    return df

# Earlier run, kept for reference (4/3 export, converted with export_labels_4_3):
#   predictions: split/predict_online_4_1_guanzhu_filter_duplicated_sample500.txt
#   labels:      checked/project_2_dataset.csv

# Model predictions: tab-separated; the file's header row is replaced by the names
# below and only columns 0-1 (text, predicted id) are read.
df_data_pred = pd.read_csv('split/predict_online_4_1_guanzhu_filter_duplicated_sample500_2.txt',
                           encoding='UTF-8', sep='\t', index_col=False,
                           header=0, usecols=[0, 1], names=['content', 'pred'])

print(df_data_pred.head())

# Manually checked labels: comma-separated export; only column 3 (the export code)
# is read, then converted to class ids.
df_data_labeled = pd.read_csv('checked/project_3_dataset.csv', encoding='UTF-8', sep=',',
                              index_col=False, header=0, usecols=[3], names=['labeled'])
df_data_labeled = process_labels(df_data_labeled, export_labels_4_18, 'labeled')
print(df_data_labeled['labeled'].value_counts())
print(df_data_labeled.head())

# The two frames are joined by row position, so both files must list the same
# samples in the same order. A length mismatch would leave NaN rows after the
# join and silently undercount agreements, so fail loudly instead.
if len(df_data_pred) != len(df_data_labeled):
    raise ValueError('prediction file has {} rows but label file has {} rows; '
                     'rows are matched by position'.format(len(df_data_pred), len(df_data_labeled)))

df_data = pd.concat([df_data_pred, df_data_labeled], axis=1)
print(df_data.head())

# Per-class count of rows where the prediction equals the checked label.
print(df_data[df_data.pred == df_data.labeled]['pred'].value_counts())